Fix torchvision compatibility check for source builds and future torch versions #3978
…h versions

The torchvision version check raised a hard ImportError for custom/source-built PyTorch installations (e.g. AMD ROCm from source with +git* suffixes), even when the actual build was functional. This also silently skipped any torch version not already in the hardcoded table, giving no warning at all for future releases.

Changes:
- Detect custom/source builds by checking the raw version string's local identifier against known standard prefixes (cu, rocm, cpu, xpu). Our custom Version() strips local identifiers via regex, so detection must happen on the raw string before parsing.
- Downgrade to a warning (instead of ImportError) for custom/source builds, since their version numbers may not follow standard PyPI release pairings.
- Add formula-based inference for future torch versions not yet in the table. The torch->torchvision minor version formula (torch 2.x -> tv 0.(x+15)) has held for every release from torch 2.0 through 2.9. For formula-predicted versions, mismatches produce a warning rather than a hard error.
- Add UNSLOTH_SKIP_TORCHVISION_CHECK=1 env var to skip the check entirely.
- Wrap importlib_version and Version calls in try/except so broken metadata never crashes the import.
for more information, see https://pre-commit.ci
Code Review
This pull request significantly improves the torchvision compatibility check by introducing detection for custom/source builds and adding forward compatibility for future torch versions using a formula-based approach. Instead of raising a hard ImportError for these cases, it now issues a warning, making the library more robust. It also adds an environment variable to skip the check entirely. My review focuses on improving the robustness of the new detection logic. I've identified a potential bug in the custom build detection regex and a minor improvement for handling the environment variable. Overall, these are great changes that will improve the user experience.
    local = raw_version_str.split("+", 1)[1]
    if not local:
        return False
    return not re.match(r"(cu|rocm|cpu|xpu)", local)
The current regex r"(cu|rocm|cpu|xpu)" with re.match is too broad and can incorrectly classify custom builds as standard ones. For example, a version with a local identifier like +cu121_custom would be considered a standard build because re.match would successfully match the prefix "cu". This could lead to unexpected ImportErrors instead of warnings for such custom builds, which undermines a key goal of this PR.
To ensure that only standard local identifiers are matched, you should use re.fullmatch to match the entire string and a more specific regex.
Suggested change:
    - return not re.match(r"(cu|rocm|cpu|xpu)", local)
    + return not re.fullmatch(r"(?:cu\d[\d.]*|rocm\d[\d.]*|cpu|xpu)", local)
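The difference between the two calls can be checked directly. The following is a minimal sketch using only the standard `re` module. The helper name `is_custom_local` and its `strict` flag are hypothetical and not taken from the PR; the two patterns are the ones quoted in this review comment.

```python
import re

# Pattern from the original code: re.match only anchors at the start of the string.
PREFIX_PATTERN = re.compile(r"(cu|rocm|cpu|xpu)")
# Pattern from the suggestion: the whole local identifier must match.
STRICT_PATTERN = re.compile(r"(?:cu\d[\d.]*|rocm\d[\d.]*|cpu|xpu)")


def is_custom_local(local: str, strict: bool = True) -> bool:
    """Return True when `local` (the text after '+') looks like a custom build.

    Hypothetical helper for illustration only.
    """
    if strict:
        return STRICT_PATTERN.fullmatch(local) is None
    return PREFIX_PATTERN.match(local) is None


if __name__ == "__main__":
    samples = ("cu121", "rocm6.3", "cpu", "cu121_custom", "custom_build", "gitf717b2a")
    for local in samples:
        loose = is_custom_local(local, strict=False)
        tight = is_custom_local(local)
        print(f"{local:>14}  prefix-only custom={loose}  fullmatch custom={tight}")
```

With the prefix-only pattern, `cu121_custom` and `custom_build` are both treated as standard builds because they start with `cu`; with `fullmatch` they are classified as custom.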
    def torchvision_compatibility_check():
        # Allow skipping via environment variable for custom environments
        if os.environ.get("UNSLOTH_SKIP_TORCHVISION_CHECK", "0") in ("1", "true", "True"):
The check for the environment variable UNSLOTH_SKIP_TORCHVISION_CHECK is case-sensitive for "true". To make it more robust and align with common practices for boolean environment variables, consider converting the value to lowercase before the check.
Suggested change:
    - if os.environ.get("UNSLOTH_SKIP_TORCHVISION_CHECK", "0") in ("1", "true", "True"):
    + if os.environ.get("UNSLOTH_SKIP_TORCHVISION_CHECK", "0").lower() in ("1", "true"):
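A small sketch of the case-insensitive check, isolated so it can be exercised without touching the real environment. The function name `skip_check_requested` and the `environ` parameter are assumptions made for illustration; only the variable name and the accepted values come from the suggestion above.

```python
import os
from typing import Mapping


def skip_check_requested(environ: Mapping[str, str] = os.environ) -> bool:
    """Return True when UNSLOTH_SKIP_TORCHVISION_CHECK is set to 1 or any casing of 'true'.

    Hypothetical helper; `environ` is injectable so the logic can be tested with a dict.
    """
    value = environ.get("UNSLOTH_SKIP_TORCHVISION_CHECK", "0")
    return value.lower() in ("1", "true")


if __name__ == "__main__":
    for raw in ("1", "true", "True", "TRUE", "0", "yes"):
        print(raw, skip_check_requested({"UNSLOTH_SKIP_TORCHVISION_CHECK": raw}))
```

Values such as `yes` or `on` are still rejected in this sketch, which matches the suggested tuple of accepted values.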
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 99a06160fd
    # Extract major.minor from the parsed version
    torch_release = torch_v.release
    if len(torch_release) < 2:
        return
    torch_major, torch_minor = torch_release[0], torch_release[1]
Allow pre-release/nightly pairs without false mismatch
This logic derives the required torchvision version purely from torch_v.release (e.g., 2.8.0.dev20240301 becomes major/minor 2,8) and later compares against a final release string like 0.23.0. For nightly/dev builds, torchvision is typically 0.23.0.dev…, which Version() ranks below 0.23.0, so the check will raise ImportError even when the matching nightly pair is installed. This is a regression for standard nightly builds with local tags like +cu121 (not treated as custom). Consider detecting pre-releases (torch_v.is_prerelease or tv_v.is_prerelease) and downgrading to a warning or skipping the strict >= comparison for nightly pairs.
…tion

Fixes three edge cases found during review:

1. Regex precision: cu/rocm now require a trailing digit (cu\d, rocm\d) to avoid false negatives on suffixes like "+custom_build" that happen to start with "cu". cpu/xpu match as exact strings only.
2. Case insensitivity: added re.IGNORECASE so "+ROCM6.3" and "+CPU" are correctly recognized as standard builds rather than custom ones.
3. Pre-release detection: nightly/dev/alpha/beta/rc builds with standard CUDA/ROCm suffixes (e.g. "2.7.0.dev20250301+cu124") now produce a warning instead of a hard ImportError. These builds commonly have version mismatches that are expected during development.
…e-release

1. Switch re.match to re.fullmatch for the custom build regex so the entire local identifier must match. Fixes false negatives where suffixes like +cu124_custom were misclassified as standard because re.match only checked the start of the string.
2. Use .lower() for the UNSLOTH_SKIP_TORCHVISION_CHECK env var so any casing of "true" / "TRUE" / etc. is accepted.
3. Check torchvision_version_raw for pre-release tags in addition to torch_version_raw, so a stable torch paired with a nightly torchvision (e.g. 0.23.0.dev...) also gets a warning instead of a hard ImportError.
Addressed all three review comments in 133deda:

- Regex precision (gemini): Switched from re.match to re.fullmatch.
- Env var casing (gemini): Changed to a .lower() comparison.
- Pre-release detection in torchvision (codex): Now checks both torch_version_raw and torchvision_version_raw.

The 384 standard-table backwards-compat tests still match old behavior exactly. All edge cases pass.
/gemini review
Code Review
This pull request significantly improves the torchvision compatibility check by adding support for source builds, pre-release versions, and future torch releases. The changes are well-structured, introducing graceful fallbacks with warnings instead of hard errors, which enhances user experience. The logic for detecting custom builds and inferring compatibility is robust. My review includes a couple of suggestions to further improve debuggability by adding logging to exception handlers, which currently fail silently. Overall, this is a high-quality contribution that makes the library more resilient.
    except Exception:
        return
The try...except Exception block silently ignores errors when fetching package versions. While this prevents the application from crashing during import, it can hide underlying environment issues, such as a corrupted installation. Logging a warning here would provide valuable feedback to the user that the compatibility check was skipped and why.
Suggested change:
    - except Exception:
    -     return
    + except Exception as e:
    +     logger.warning(f"Unsloth: Could not determine torch/torchvision versions, skipping compatibility check. Error: {e}")
    +     return
    except Exception:
        return
Similar to the previous block, this try...except silently fails when parsing version strings. This can make it difficult to diagnose problems related to malformed version numbers. Logging a warning that includes the problematic version strings would greatly improve debuggability for users.
Suggested change:
    except Exception as e:
        logger.warning(
            f"Unsloth: Could not parse torch/torchvision versions, skipping compatibility check. "
            f"Versions: torch='{torch_version_raw}', torchvision='{torchvision_version_raw}'. Error: {e}"
        )
        return
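Both suggestions follow the same pattern: catch the failure, log why the check is being skipped, and return without raising. A minimal, self-contained sketch of that pattern is shown here. The function name `read_raw_versions`, the logger name, and the `packages` parameter are hypothetical; `importlib.metadata.version` is the standard-library call for reading installed package metadata.

```python
import logging
from importlib.metadata import version as importlib_version

# Hypothetical logger name, used only for this sketch.
logger = logging.getLogger("unsloth_sketch")


def read_raw_versions(packages=("torch", "torchvision")):
    """Return {package: raw version string}, or None if any lookup fails.

    A failed lookup is logged as a warning instead of being swallowed silently,
    so a broken installation is visible without crashing the import.
    """
    try:
        return {name: importlib_version(name) for name in packages}
    except Exception as e:
        logger.warning(
            "Could not determine versions for %s, skipping compatibility check. Error: %s",
            ", ".join(packages),
            e,
        )
        return None


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    print(read_raw_versions())
```

Returning `None` lets the caller decide to skip the check, which keeps the "never crash the import" goal while still surfacing the reason in the logs.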
    # Detect nightly/dev/alpha/beta/rc builds from the raw version string.
    # These often have version mismatches that are expected.
    _pre_tags = (".dev", "a0", "b0", "rc", "alpha", "beta", "nightly")
The _pre_tags tuple is a constant collection of strings. It's a good practice to define such constants at the module level (e.g., as _TORCHVISION_PRE_RELEASE_TAGS) rather than inside a function. This improves readability and maintainability by making it clear that this is a fixed set of values used for the check.
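One way to apply this suggestion is sketched here, with the tuple hoisted to module level under the name proposed in the comment. The helper `looks_like_prerelease` is hypothetical, and restricting the search to the part before `+` is an assumption of this sketch (it avoids matching tags such as `a0` inside a git hash in the local identifier); the PR itself is described only as checking the raw version string.

```python
# Module-level constant, named as proposed in the review comment.
_TORCHVISION_PRE_RELEASE_TAGS = (".dev", "a0", "b0", "rc", "alpha", "beta", "nightly")


def looks_like_prerelease(raw_version: str) -> bool:
    """Return True if the public part of `raw_version` contains a pre-release tag.

    Hypothetical helper. Only the text before '+' is inspected (an assumption of
    this sketch), so local identifiers like '+gita0b0' do not trigger a match.
    """
    public_part = raw_version.split("+", 1)[0].lower()
    return any(tag in public_part for tag in _TORCHVISION_PRE_RELEASE_TAGS)


if __name__ == "__main__":
    for raw in ("2.7.0.dev20250301+cu124", "2.7.0+cu124", "0.23.0.dev20250301", "2.6.0rc1"):
        print(raw, looks_like_prerelease(raw))
```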
…h versions (unslothai#3978)

* Fix torchvision compatibility check for source builds and future torch versions
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Address review: stricter regex, case insensitivity, pre-release detection
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Address PR review comments: fullmatch, env var casing, torchvision pre-release
* [pre-commit.ci] auto fixes from pre-commit.com hooks

---------

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Summary
torch==2.7.0+gitf717b2a, torchvision==0.21.0+7af6987), torchvision_compatibility_check() raised a hard ImportError that blocked from unsloth import FastLanguageModel entirely, even though the build was functional.

Changes

Custom/source build detection -- Detects source builds by checking the raw version string's local identifier (the part after +) against known standard prefixes (cu\d, rocm\d, cpu, xpu). This must operate on the raw string from importlib_version() because our custom Version() wrapper strips local identifiers via regex before parsing. The regex requires cu/rocm to be followed by a digit (avoiding false negatives on suffixes like +custom_build), is case-insensitive (handles +ROCM6.3), and matches cpu/xpu as exact strings.

Pre-release/nightly detection -- Detects .dev, a0, b0, rc, alpha, beta, nightly tags in the raw torch version string. Nightly/dev/rc builds with standard CUDA or ROCm suffixes (e.g. 2.7.0.dev20250301+cu124) now produce a warning instead of a hard ImportError, since these builds commonly have version mismatches that are expected during development.

Formula-based forward compatibility -- The torch-to-torchvision minor version mapping follows a consistent formula that has held for every release from torch 1.7 through 2.9:

- torch 1.x -> torchvision 0.(x + 1) (verified: 1.7 through 1.13)
- torch 2.x -> torchvision 0.(x + 15) (verified: 2.0 through 2.9)

For versions not in the known table, the formula is used as a fallback. Mismatches from formula-predicted versions produce a warning rather than a hard error, since the formula could in theory change.

Graceful degradation -- Wraps importlib_version() and Version() calls in try/except so broken package metadata never crashes the import.

Environment variable override -- UNSLOTH_SKIP_TORCHVISION_CHECK=1 skips the check entirely for environments where the user knows the build is correct.

Behavior matrix:
+cu124)
+git*)
.dev*+cu124)
UNSLOTH_SKIP_TORCHVISION_CHECK=1

Validated with 384 standard-table backwards-compat tests (100% match with old behavior), plus 50+ additional tests covering custom builds, pre-releases, forward compat, edge cases, and adversarial inputs.
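The fallback formula described in the PR can be written as a small pure function. This is a sketch under stated assumptions, not the PR's implementation: the function name `predicted_torchvision_release` is hypothetical, and for torch 2.x minors beyond 2.9 the result is an extrapolation that the PR itself treats as warning-only.

```python
from typing import Optional, Tuple


def predicted_torchvision_release(torch_major: int, torch_minor: int) -> Optional[Tuple[int, int]]:
    """Predict the (major, minor) torchvision release paired with a torch release.

    Hypothetical helper illustrating the formula quoted in the PR description:
      torch 1.x -> torchvision 0.(x + 1)   (stated as verified for 1.7 through 1.13)
      torch 2.x -> torchvision 0.(x + 15)  (stated as verified for 2.0 through 2.9)
    Returns None when the formula does not cover the given torch version.
    """
    if torch_major == 1 and torch_minor >= 7:
        return (0, torch_minor + 1)
    if torch_major == 2:
        return (0, torch_minor + 15)
    return None


if __name__ == "__main__":
    for major, minor in ((1, 7), (1, 13), (2, 0), (2, 9), (3, 0)):
        print((major, minor), "->", predicted_torchvision_release(major, minor))
```

Returning `None` for unknown major versions keeps the caller free to skip the check rather than guess.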
Test plan
+gitXXX, +HEXHASH) produce a warning instead of ImportError
UNSLOTH_SKIP_TORCHVISION_CHECK=1 bypasses the check
+ROCM6.3, +CPU)
+custom_build, +cust, etc.)